import torch
from torch.nn import functional as F


def vrbo(args, val_data_list, parameters, hparams, hparams_old, grad_x, grad_y, out_f, reg_f):
    """Run one inner loop with recursive-difference (variance-reduced) estimators.

    The carried-over hypergradient estimate ``grad_x`` and inner-gradient
    estimate ``grad_y`` are first corrected for the change
    ``hparams_old -> hparams``; both terms of that correction are evaluated at
    the incoming ``parameters``. Then ``args.iterations`` inner steps
    ``y <- y - args.inner_lr * u`` are taken. Each step refreshes ``u`` and
    ``v`` by the difference between the newest and the previous inner iterate,
    evaluated at fixed ``hparams`` on the same batch.

    Args:
        args: this function reads ``inner_lr`` and ``iterations``; the helpers
            it calls may read more.
        val_data_list: indexable of ``(data_list, labels_list)`` batches.
            Entries ``0 .. args.iterations`` are used. Element ``[1]`` of each
            part feeds the inner gradient, and the whole batch is passed to
            ``stocbio``.
        parameters: current inner variable.
        hparams, hparams_old: current and previous outer variable.
        grad_x, grad_y: previous hypergradient / inner-gradient estimates.
        out_f, reg_f: model forward and inner-loss function, forwarded to
            ``gradient_gy`` / ``stocbio``.

    Returns:
        ``(new_parameters, v_t, u_t)``. ``u_t`` is built from
        ``create_graph=True`` gradients, so the returned tensors may carry
        autograd history from every inner step.

    NOTE(review): the initial correction compares (parameters, hparams) with
    (parameters, hparams_old); no previous inner iterate is involved. Confirm
    this is the intended estimator.
    """
    first_batch = val_data_list[0]
    inputs, targets = first_batch
    logits = out_f(inputs[1], parameters)
    gy_curr = gradient_gy(args, targets[1], parameters, inputs[1], hparams, logits, reg_f)
    gy_prev = gradient_gy(args, targets[1], parameters, inputs[1], hparams_old, logits, reg_f)
    gx_curr = stocbio(parameters, hparams, first_batch, args, out_f, reg_f)
    gx_prev = stocbio(parameters, hparams_old, first_batch, args, out_f, reg_f)

    # Correct the carried-over estimators for the change in hparams.
    v_t = grad_x + gx_curr - gx_prev
    u_t = grad_y + gy_curr - gy_prev

    y_prev = parameters
    y_curr = y_prev - args.inner_lr * u_t
    for step in range(1, args.iterations + 1):
        batch = val_data_list[step]
        inputs, targets = batch
        # Inner gradient at the newest and at the previous iterate (same batch,
        # same hparams); their difference refreshes u_t.
        logits_curr = out_f(inputs[1], y_curr)
        gy_curr = gradient_gy(args, targets[1], y_curr, inputs[1], hparams, logits_curr, reg_f)
        logits_prev = out_f(inputs[1], y_prev)
        gy_prev = gradient_gy(args, targets[1], y_prev, inputs[1], hparams, logits_prev, reg_f)
        gx_curr = stocbio(y_curr, hparams, batch, args, out_f, reg_f)
        gx_prev = stocbio(y_prev, hparams, batch, args, out_f, reg_f)

        # Kept out-of-place (not `+=`): in-place updates of tensors that carry
        # autograd history can break a later backward pass.
        v_t = v_t + gx_curr - gx_prev
        u_t = u_t + gy_curr - gy_prev
        y_prev, y_curr = y_curr, y_curr - args.inner_lr * u_t

    return y_curr, v_t, u_t


def mstsa(prerious_update, eta, val_data_list, args, params, params_next, hparams, hparams_old, out_f, reg_f):
    """Return a momentum-corrected hypergradient estimate.

    Computes::

        eta * g_new + (1 - eta) * (prerious_update + g_new - g_old)

    where ``g_new = stocbio(params_next, hparams, ...)`` and
    ``g_old = stocbio(params, hparams_old, ...)``, both evaluated on the same
    ``val_data_list`` batch.

    Args:
        prerious_update: previous estimate (parameter name misspelled in the
            original API; kept for keyword compatibility).
        eta: weight on the fresh sample, presumably in [0, 1].
    """
    g_new = stocbio(params_next, hparams, val_data_list, args, out_f, reg_f)
    g_old = stocbio(params, hparams_old, val_data_list, args, out_f, reg_f)
    # Previous estimate shifted by the change in hypergradient between iterates.
    corrected = prerious_update + g_new - g_old
    return eta * g_new + (1 - eta) * corrected


def stocbio(params, hparams, val_data_list, args, out_f, reg_f):
    """Estimate the implicit (second-order) part of the hypergradient.

    Let ``f`` be the mean cross-entropy outer loss (``gradient_fy``) and ``g``
    the inner objective ``reg_f(params, hparams, per_sample_ce)``
    (``gradient_gy``). This returns::

        -d/d(hparams) [ grad_y g(params, hparams) . v_Q ]

    ``v_Q`` is built from ``args.hessian_q`` Hessian-vector products of ``g``
    w.r.t. ``params``, all on the same batch. It is intended to approximate
    ``H^{-1} grad_y f``; see the NOTE(review) in the body for how the series
    is actually combined.

    Args:
        params: inner variable. Treated as a single tensor (it is flattened
            with ``torch.reshape``) and must require grad.
        hparams: outer variable. Must participate in ``reg_f``'s graph,
            otherwise the final ``torch.autograd.grad`` call raises.
        val_data_list: ``(data_list, labels_list)``. Index 0 is used for the
            outer-loss gradient, index 1 for the Hessian-vector products, and
            index 2 for the mixed second derivative.
        args: reads ``args.eta`` (series step size) and ``args.hessian_q``
            (number of Hessian-vector products; ``<= 0`` means none).
        out_f: ``out_f(data, params)``; its result is fed to cross-entropy as
            logits.
        reg_f: ``reg_f(params, hparams, per_sample_loss)`` -> scalar inner loss.

    Returns:
        The negated gradient w.r.t. ``hparams``; shaped like ``hparams``,
        assuming it is a single tensor.
    """
    data_list, labels_list = val_data_list

    # Right-hand side: outer-loss gradient w.r.t. params, as a detached column.
    output = out_f(data_list[0], params)
    fy_gradient = gradient_fy(args, labels_list[0], params, data_list[0], output)
    v_0 = torch.unsqueeze(torch.reshape(fy_gradient, [-1]), 1).detach()

    # Differentiating G(y) = y - eta * grad_y g(y), dotted with v, w.r.t. params
    # gives (I - eta * H) v. gradient_gy uses create_graph=True and each product
    # retains the graph, so the same G is reused for every iteration.
    output = out_f(data_list[1], params)
    gy_gradient = gradient_gy(args, labels_list[1], params, data_list[1], hparams, output, reg_f)
    g_map = torch.reshape(params, [-1]) - args.eta * torch.reshape(gy_gradient, [-1])

    z_list = []
    for _ in range(args.hessian_q):
        jacobian = torch.matmul(g_map, v_0)
        v_new = torch.autograd.grad(jacobian, params, retain_graph=True)[0]
        v_0 = torch.unsqueeze(torch.reshape(v_new, [-1]), 1).detach()
        z_list.append(v_0)

    # NOTE(review): this yields eta*(I-eta H)^Q v + sum_{q=1..Q} (I-eta H)^q v,
    # whereas the textbook truncated Neumann series is
    # eta * sum_{q=0..Q} (I-eta H)^q v. Kept unchanged to preserve existing
    # results -- confirm intent.
    # torch.stack rejects an empty list, so with hessian_q <= 0 the tail is
    # zero and v_Q reduces to eta * v_0 (previously this raised).
    series_tail = torch.sum(torch.stack(z_list), dim=0) if z_list else torch.zeros_like(v_0)
    v_q = args.eta * v_0 + series_tail

    # Mixed second derivative: differentiate grad_y g . v_Q w.r.t. hparams.
    output = out_f(data_list[2], params)
    gy_gradient = gradient_gy(args, labels_list[2], params, data_list[2], hparams, output, reg_f)
    gy_gradient = torch.reshape(gy_gradient, [-1])
    gyx_gradient = torch.autograd.grad(
        torch.matmul(gy_gradient, v_q.detach()), hparams, retain_graph=True
    )[0]
    return -gyx_gradient

def gradient_fy(args, labels, params, data, output):
    """Return d(mean cross-entropy of ``output`` vs ``labels``) / d(params).

    ``output`` must have been computed from ``params``. The gradient is taken
    without ``create_graph``, so the result is not itself differentiable, and
    the graph behind ``output`` is released afterwards. ``args`` and ``data``
    are unused; presumably they are accepted so call sites mirror
    ``gradient_gy``. If ``params`` is a sequence of tensors, only the first
    gradient is returned.
    """
    outer_loss = F.cross_entropy(output, labels)
    first_grad, *_ = torch.autograd.grad(outer_loss, params)
    return first_grad

def gradient_gy(args, labels_cp, params, data, hparams, output, reg_f):
    """Return the differentiable gradient of the inner objective w.r.t. ``params``.

    The inner objective is ``reg_f(params, hparams, per_sample_ce)``, where
    ``per_sample_ce`` is the unreduced cross-entropy of ``output`` (computed
    from ``params``) against ``labels_cp``. ``reg_f`` must return a scalar.
    ``create_graph=True`` keeps the result differentiable; the
    Hessian-vector and mixed-derivative products in ``stocbio`` rely on this.
    ``args`` and ``data`` are unused. If ``params`` is a sequence of tensors,
    only the first gradient is returned.
    """
    # MNIST data hyper-cleaning: keep one loss per sample (reduction='none')
    # and let reg_f combine them. The NewsGroup l2-regularization experiments
    # used the batch-mean loss instead: F.cross_entropy(output, labels_cp).
    per_sample_ce = F.cross_entropy(output, labels_cp, reduction='none')
    inner_objective = reg_f(params, hparams, per_sample_ce)
    first_grad, *_ = torch.autograd.grad(inner_objective, params, create_graph=True)
    return first_grad
